package com.garbagecode.resultbounds;

import java.util.List;
import java.util.Properties;

import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ResultMap;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import com.garbagecode.resultbounds.dialect.Dialect;
import com.garbagecode.resultbounds.dialect.MySqlDialect;

/**
 * 
 * 要使用这个分页插件，需要在 MyBatis 中如此配置
 * <configuration>
 * ...
 *   <plugins>
 *     <plugin interceptor="com.garbagecode.resultbounds.PaginationInterceptor">
 *       <!-- 这个 value 里面的值可以替换成自己实现的 Dialect 类 -->
 *       <property name="dialect" value="com.garbagecode.resultbounds.dialect.MySqlDialect" />
 *     </plugin>
 *   </plugins>
 * ...
 * <configuration>
 */
@Intercepts({ @Signature(type = Executor.class, 
                         method = "query", 
                         args = { MappedStatement.class, 
                                  Object.class,
                                  RowBounds.class, 
                                  ResultHandler.class }) })
public class PaginationInterceptor implements Interceptor {
  private Dialect dialect = new MySqlDialect(); // 用于生成分页语句和统计记录总数
  private static MyBatisClassCreator  mybatisClassCreator = new MyBatisClassCreator();
  private static List<ResultMap> totalCountResultMaps = mybatisClassCreator.createResultMaps(Long.class); // 获取统计结果的时候用到
  public static final String PROPERTY_DIALECT_NAME = "dialect";
  
  protected Dialect getDialect() {
    return dialect;
  }
  
  /**
   * 
   * 分页的核心代码
   * 
   */
  @SuppressWarnings("rawtypes")
  @Override
  public Object intercept(Invocation invocation) throws Throwable {
    Executor executor = (Executor) invocation.getTarget();
    Object[] args = invocation.getArgs();
    MappedStatement mappedStatement = (MappedStatement) args[0];
    Object parameter = args[1];
    RowBounds rowBounds = (RowBounds) args[2];

    if (rowBounds == null) {
      return invocation.proceed();
    }
    
    // 若 rowBounds 是 TotalCount 类型，表示这次查询语句是要统计符合查询条件的记录总数
    if (rowBounds instanceof TotalCount) {
      return getTotal(executor, args, mappedStatement, parameter);
    }

    // 若 rowBounds 是 ResultBounds 类型，那么初始它的 totalCount 属性，
    // 这样之后就可以通过 ResultBoudns.getTotal 获取符合查询条件的记录总数
    if (rowBounds instanceof ResultBounds) {
      setTotal((ResultBounds) rowBounds, mappedStatement, parameter, invocation);
    }
    
    // 把 rowBounds 转换成 ResultBounds 类型是为了接下来的代码会好写些
    ResultBounds resultBounds = rowBounds instanceof ResultBounds 
        ? (ResultBounds) rowBounds 
        : new ResultBounds(rowBounds);
    
    // 判断是否需要分页，不需要分页的话，那就不必执行花括号里的代码
    if (resultBounds.getOffset() != RowBounds.NO_ROW_OFFSET || 
        resultBounds.getLimit() != RowBounds.NO_ROW_LIMIT || 
        (resultBounds.getOrder() != null && resultBounds.getOrder().trim().length() > 0)) {
      //
      // 分页的原理是创建一些改动过的对象，这些对象里的 SQL 语句、返回结果类型被
      // 替换成要分页需要设的值，然后把这些对象作为参数调用 executor.query 方法
      //
      // 步骤 1: 获取 BoundSql 对象，之后要从这个对象里获取原 SQL 语句
      BoundSql boundSql = mappedStatement.getBoundSql(parameter);
      
      // 步骤2: 获取生成的分页 SQL 语句
      String paginationSql = dialect.getPaginationSql(
          boundSql.getSql(), 
          resultBounds.getOffset(),
          resultBounds.getLimit(), 
          resultBounds.getOrder());
      
      // 步骤3: 创建一个 SqlSource 对象，这个对象里的 SQL 语句就是步骤 2 的分页 SQL 语句
      SqlSource temporarySqlSource = mybatisClassCreator.createSqlSource(
          mappedStatement.getConfiguration(), 
          paginationSql,
          boundSql.getParameterMappings(), 
          boundSql.getParameterObject());
      
      // 步骤4: 创建一个 MappedStatement 对象，这个对象里的字段也是动过手脚的
      MappedStatement newMappedStatement = mybatisClassCreator.createMappedStatement(mappedStatement, temporarySqlSource, null);

      // 步骤5: 用动过手脚的参数调用这个方法返回的内容就是分页后的记录
      return executor.query(newMappedStatement, parameter, RowBounds.DEFAULT, (ResultHandler) args[3]);
    }

    return invocation.proceed();
  }
  
  /**
   * 返回符合查询条件的记录总数
   * @param executor
   * @param args
   * @param mappedStatement
   * @param parameter
   * @return
   * @throws Throwable
   */
  @SuppressWarnings("rawtypes")
  protected Object getTotal(Executor executor, 
                            Object[] args, 
                            MappedStatement mappedStatement, 
                            Object parameter) throws Throwable {
    // 步骤1: 获取 BoundSql 对象，之后要从这个对象里获取原 SQL 语句
    BoundSql boundSql = mappedStatement.getBoundSql(parameter);
    // 步骤2: 获取生成的用于统计记录的 SQL 语句
    String totalCountSql = dialect.getTotalCountSql(boundSql.getSql());
    
    // 步骤3: 创建一个 SqlSource 对象，这个对象里的 SQL 语句就是步骤 2 的统计 SQL 语句
    SqlSource newSqlSource = mybatisClassCreator.createSqlSource(
        mappedStatement.getConfiguration(), 
        totalCountSql, 
        boundSql.getParameterMappings(),
        boundSql.getParameterObject());

    // 步骤4: totalCountResultMaps 好像是确定查询要返回的类型，这里要求查询返回 Long 类型
    MappedStatement newMappedStatement = mybatisClassCreator.createMappedStatement(mappedStatement, 
                                                                                   newSqlSource, 
                                                                                   totalCountResultMaps);

    // 步骤5: 用动过手脚的参数调用这个方法返回的内容就是满足查询条件的记录总数
    return executor.query(newMappedStatement, parameter, RowBounds.DEFAULT, (ResultHandler) args[3]);
  }
  
  /**
   * 设置符合查询条件的记录总数
   * @param resultBounds
   * @param mappedStatement
   * @param parameter
   * @param invocation
   */
  protected void setTotal(ResultBounds resultBounds, 
                          MappedStatement mappedStatement, 
                          Object parameter, 
                          Invocation invocation) {
    resultBounds.setTotalCount(new TotalCount(mappedStatement.getId(), parameter));
  }
  
  @Override
  public Object plugin(Object target) {
    return Plugin.wrap(target, this);
  }

  @Override
  public void setProperties(Properties properties) {
    String dialectName = properties.getProperty(PROPERTY_DIALECT_NAME);

    //
    // 初始属性 dialect
    //
    if (dialectName != null) {
      try {
        dialect = (Dialect) Class.forName(dialectName.trim()).newInstance();
      } catch (InstantiationException | IllegalAccessException | ClassNotFoundException e) {
        throw new RuntimeException(e);
      }
    }
  }
  
}
